{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License."
      ],
      "metadata": {
        "id": "OsFaZscKSPvo"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Run ML inference with multiple differently-trained models\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/per_key_models.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/per_key_models.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ],
      "metadata": {
        "id": "ZUSiAR62SgO8"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Running inference with multiple differently trained models that perform the same task is useful in many scenarios, including the following:\n",
        "\n",
        "* You want to compare the performance of multiple different models.\n",
        "* You have models trained on different datasets that you want to use conditionally based on additional metadata.\n",
        "\n",
        "In Apache Beam, the recommended way to run inference is to use the `RunInference` transform. By using a `KeyedModelHandler`, you can efficiently run inference with hundreds of models without having to manage memory yourself.\n",
        "\n",
        "This notebook demonstrates how to use a `KeyedModelHandler` to run inference in an Apache Beam pipeline with multiple different models on a per-key basis. This notebook uses pretrained pipelines from Hugging Face. Before continuing with this notebook, it is recommended that you walk through the [Use RunInference in Apache Beam](https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch_tensorflow_sklearn.ipynb) notebook."
      ],
      "metadata": {
        "id": "ZAVOrrW2An1n"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Install dependencies\n",
        "\n",
        "Install both Apache Beam and the dependencies needed by Hugging Face."
      ],
      "metadata": {
        "id": "_fNyheQoDgGt"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "B-ENznuJqArA",
        "outputId": "f72963fc-82db-4d0d-9225-07f6b501e256"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            ""
          ]
        }
      ],
      "source": [
        "!pip install 'apache_beam[interactive,gcp]>=2.51.0' --quiet\n",
        "!pip install torch --quiet\n",
        "!pip install transformers --quiet\n",
        "\n",
        "# To use the newly installed versions, restart the runtime.\n",
        "exit()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Dict\n",
        "from typing import Iterable\n",
        "from typing import Tuple\n",
        "\n",
        "from transformers import pipeline\n",
        "\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from apache_beam.ml.inference.base import KeyModelMapping\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler\n",
        "from apache_beam.ml.inference.base import RunInference"
      ],
      "metadata": {
        "id": "wUmBEglvsOYW"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define the model configurations\n",
        "\n",
        "A model handler is the Apache Beam object that defines the configuration needed to load and invoke a model. Because this example uses two models, we define two model handlers, one for each model. Because both models are encapsulated within Hugging Face pipelines, we use the `HuggingFacePipelineModelHandler` model handler.\n",
        "\n",
        "To see how the models behave on their own, also load them directly as Hugging Face pipelines, and then run each one against a sample sentence. The two models produce different outputs."
      ],
      "metadata": {
        "id": "uEqljVgCD7hx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "distilbert_mh = HuggingFacePipelineModelHandler('text-classification', model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
        "roberta_mh = HuggingFacePipelineModelHandler('text-classification', model=\"roberta-large-mnli\")\n",
        "\n",
        "distilbert_pipe = pipeline('text-classification', model=\"distilbert-base-uncased-finetuned-sst-2-english\")\n",
        "roberta_large_pipe = pipeline('text-classification', model=\"roberta-large-mnli\")"
      ],
      "metadata": {
        "id": "v2NJT5ZcxgH5",
        "outputId": "3924d72e-5c49-477d-c50f-6d9098f5a4b2"
      },
      "execution_count": 2,
      "outputs": [
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "distilbert_pipe(\"This restaurant is awesome\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "H3nYX9thy8ec",
        "outputId": "826e3285-24b9-47a8-d2a6-835543fdcae7"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'label': 'POSITIVE', 'score': 0.9998743534088135}]"
            ]
          },
          "metadata": {},
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "roberta_large_pipe(\"This restaurant is awesome\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IIfc94ODyjUg",
        "outputId": "94ec8afb-ebfb-47ce-9813-48358741bc6b"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'label': 'NEUTRAL', 'score': 0.7313134670257568}]"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Define the examples\n",
        "\n",
        "Define examples to input into the pipeline. The examples include the correct classifications."
      ],
      "metadata": {
        "id": "yd92MC7YEsTf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "examples = [\n",
        "    (\"This restaurant is awesome\", \"positive\"),\n",
        "    (\"This restaurant is bad\", \"negative\"),\n",
        "    (\"I feel fine\", \"neutral\"),\n",
        "    (\"I love chocolate\", \"positive\"),\n",
        "]"
      ],
      "metadata": {
        "id": "5HAziWEavQws"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "To feed the examples into `RunInference`, you need distinct keys that map each example to a model. In this case, to make it possible to extract the actual sentiment of the example later, define keys in the form `<model_name>-<actual_sentiment>`."
      ],
      "metadata": {
        "id": "r6GXL5PLFBY7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class FormatExamples(beam.DoFn):\n",
        "  \"\"\"\n",
        "  Map each example to a tuple of ('<model_name>-<actual_sentiment>', 'example').\n",
        "  Use these keys to map our elements to the correct models.\n",
        "  \"\"\"\n",
        "  def process(self, element: Tuple[str, str]) -> Iterable[Tuple[str, str]]:\n",
        "    yield (f'distilbert-{element[1]}', element[0])\n",
        "    yield (f'roberta-{element[1]}', element[0])"
      ],
      "metadata": {
        "id": "p2uVwws8zRpg"
      },
      "execution_count": 6,
      "outputs": []
    },
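    {
      "cell_type": "markdown",
      "source": [
        "The key scheme above can be sketched in plain Python, without Apache Beam. The helper names `make_key` and `parse_key` are hypothetical and exist only for this illustration. The sketch assumes that the sentiment label never contains a hyphen, so it splits on the last hyphen, which keeps a model name that contains hyphens intact.\n",
        "\n",
        "```python\n",
        "from typing import Tuple\n",
        "\n",
        "\n",
        "def make_key(model_name: str, actual_sentiment: str) -> str:\n",
        "  # Join the two parts in the form '<model_name>-<actual_sentiment>'.\n",
        "  return f'{model_name}-{actual_sentiment}'\n",
        "\n",
        "\n",
        "def parse_key(key: str) -> Tuple[str, str]:\n",
        "  # Split on the last hyphen so that a model name such as\n",
        "  # 'roberta-large' is not cut in half.\n",
        "  model_name, _, actual_sentiment = key.rpartition('-')\n",
        "  return model_name, actual_sentiment\n",
        "\n",
        "\n",
        "key = make_key('distilbert', 'positive')\n",
        "print(key)\n",
        "print(parse_key(key))\n",
        "print(parse_key('roberta-large-neutral'))\n",
        "```\n",
        "\n",
        "The model names used in this notebook contain no hyphens, so splitting on every hyphen also works here. Splitting on the last hyphen is only a more defensive choice if the naming scheme grows."
      ],
      "metadata": {}
    },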
    {
      "cell_type": "markdown",
      "source": [
        "Use the formatted keys to define a `KeyedModelHandler` that maps keys to the `ModelHandler` used for those keys. The `KeyedModelHandler` constructor accepts an optional `max_models_per_worker_hint`, which limits the number of models that can be held in a single worker process at one time. If your worker might run out of memory, use this option. For more information about managing memory, see [Use a keyed ModelHandler](https://beam.apache.org/documentation/ml/about-ml/#use-a-keyed-modelhandler-object)."
      ],
      "metadata": {
        "id": "IP65_5nNGIb8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "per_key_mhs = [\n",
        "    KeyModelMapping(['distilbert-positive', 'distilbert-neutral', 'distilbert-negative'], distilbert_mh),\n",
        "    KeyModelMapping(['roberta-positive', 'roberta-neutral', 'roberta-negative'], roberta_mh)\n",
        "]\n",
        "mh = KeyedModelHandler(per_key_mhs, max_models_per_worker_hint=2)"
      ],
      "metadata": {
        "id": "DZpfjeGL2hMG"
      },
      "execution_count": 7,
      "outputs": []
    },
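    {
      "cell_type": "markdown",
      "source": [
        "As a mental model only, the mapping above can be pictured as a lookup from each key to one handler, combined with a cap on how many models stay loaded at once. The following plain-Python sketch is not Apache Beam's implementation: the function `expand_mappings`, the class `BoundedModelCache`, and the least-recently-used eviction policy are assumptions made for illustration.\n",
        "\n",
        "```python\n",
        "from collections import OrderedDict\n",
        "from typing import Callable, Dict, List, Tuple\n",
        "\n",
        "\n",
        "def expand_mappings(mappings: List[Tuple[List[str], str]]) -> Dict[str, str]:\n",
        "  # Flatten (keys, model_id) pairs into a key -> model_id lookup.\n",
        "  lookup = {}\n",
        "  for keys, model_id in mappings:\n",
        "    for key in keys:\n",
        "      lookup[key] = model_id\n",
        "  return lookup\n",
        "\n",
        "\n",
        "class BoundedModelCache:\n",
        "  # Hypothetical cache: keeps at most max_models loaded and evicts the\n",
        "  # least recently used model when the cap is reached.\n",
        "  def __init__(self, loader: Callable[[str], object], max_models: int):\n",
        "    self._loader = loader\n",
        "    self._max_models = max_models\n",
        "    self._models = OrderedDict()\n",
        "    self.load_count = 0\n",
        "\n",
        "  def get(self, model_id: str) -> object:\n",
        "    if model_id in self._models:\n",
        "      self._models.move_to_end(model_id)  # Mark as most recently used.\n",
        "      return self._models[model_id]\n",
        "    if len(self._models) >= self._max_models:\n",
        "      self._models.popitem(last=False)  # Drop the least recently used.\n",
        "    self._models[model_id] = self._loader(model_id)\n",
        "    self.load_count += 1\n",
        "    return self._models[model_id]\n",
        "\n",
        "\n",
        "lookup = expand_mappings([\n",
        "    (['distilbert-positive', 'distilbert-negative'], 'distilbert'),\n",
        "    (['roberta-positive', 'roberta-negative'], 'roberta'),\n",
        "])\n",
        "cache = BoundedModelCache(\n",
        "    loader=lambda model_id: f'<{model_id} model>', max_models=1)\n",
        "for key in ['distilbert-positive', 'roberta-negative', 'distilbert-negative']:\n",
        "  print(key, '->', cache.get(lookup[key]))\n",
        "print('loads:', cache.load_count)\n",
        "```\n",
        "\n",
        "With a cap of one model, the third key forces the first model to be loaded again. A lower cap saves memory at the cost of more reloads, which is the trade-off to weigh when choosing a value for `max_models_per_worker_hint`."
      ],
      "metadata": {}
    },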
    {
      "cell_type": "markdown",
      "source": [
        "## Postprocess the results\n",
        "\n",
        "For each keyed input, the `RunInference` transform outputs a tuple that contains the following objects:\n",
        "\n",
        "* the original key\n",
        "* a `PredictionResult` object containing the original example and the inference\n",
        "\n",
        "Use those outputs to extract the relevant data. Then, to compare each model's prediction, group this data by the original example."
      ],
      "metadata": {
        "id": "_a4ZmnD5FSeG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class ExtractResults(beam.DoFn):\n",
        "  \"\"\"\n",
        "  Extract the relevant data from the PredictionResult object.\n",
        "  \"\"\"\n",
        "  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[Tuple[str, Dict[str, str]]]:\n",
        "    # Keys have the form '<model_name>-<actual_sentiment>'.\n",
        "    model, actual_sentiment = element[0].split('-')\n",
        "    result = element[1]\n",
        "    example = result.example\n",
        "    predicted_sentiment = result.inference[0]['label']\n",
        "\n",
        "    yield (example, {'model': model, 'actual_sentiment': actual_sentiment, 'predicted_sentiment': predicted_sentiment})"
      ],
      "metadata": {
        "id": "FOwFNQA053TG"
      },
      "execution_count": 8,
      "outputs": []
    },
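    {
      "cell_type": "markdown",
      "source": [
        "To make the shape of the data concrete, the following sketch applies the same extraction logic to a hand-built element, without Apache Beam or a model. `FakePredictionResult` is a hypothetical stand-in that models only the `example` and `inference` fields used above, and the inference value is copied from the DistilBERT output shown earlier in this notebook.\n",
        "\n",
        "```python\n",
        "from typing import Dict, List, NamedTuple, Tuple\n",
        "\n",
        "\n",
        "class FakePredictionResult(NamedTuple):\n",
        "  # Hypothetical stand-in for PredictionResult; only two fields are modeled.\n",
        "  example: str\n",
        "  inference: List[Dict[str, object]]\n",
        "\n",
        "\n",
        "def extract(\n",
        "    element: Tuple[str, FakePredictionResult]) -> Tuple[str, Dict[str, object]]:\n",
        "  key, result = element\n",
        "  # Keys have the form '<model_name>-<actual_sentiment>'.\n",
        "  model, actual_sentiment = key.split('-')\n",
        "  return (result.example, {\n",
        "      'model': model,\n",
        "      'actual_sentiment': actual_sentiment,\n",
        "      'predicted_sentiment': result.inference[0]['label'],\n",
        "  })\n",
        "\n",
        "\n",
        "element = (\n",
        "    'distilbert-positive',\n",
        "    FakePredictionResult(\n",
        "        example='This restaurant is awesome',\n",
        "        inference=[{'label': 'POSITIVE', 'score': 0.9998743534088135}]))\n",
        "print(extract(element))\n",
        "```\n",
        "\n",
        "The output is keyed by the example text, which is what lets the later grouping step collect both models' predictions for the same sentence."
      ],
      "metadata": {}
    },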
    {
      "cell_type": "markdown",
      "source": [
        "Finally, print the results produced by each model."
      ],
      "metadata": {
        "id": "JVnv4gGbFohk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class PrintResults(beam.DoFn):\n",
        "  \"\"\"\n",
        "  Print the results produced by each model along with the actual sentiment.\n",
        "  \"\"\"\n",
        "  def process(self, element: Tuple[str, Iterable[Dict[str, str]]]):\n",
        "    example = element[0]\n",
        "    # GroupByKey yields an iterable of values, which might not support\n",
        "    # indexing on every runner, so materialize it as a list first.\n",
        "    results = list(element[1])\n",
        "    actual_sentiment = results[0]['actual_sentiment']\n",
        "    predicted_sentiment_1 = results[0]['predicted_sentiment']\n",
        "    model_1 = results[0]['model']\n",
        "    predicted_sentiment_2 = results[1]['predicted_sentiment']\n",
        "    model_2 = results[1]['model']\n",
        "\n",
        "    if model_1 == 'distilbert':\n",
        "      distilbert_prediction = predicted_sentiment_1\n",
        "      roberta_prediction = predicted_sentiment_2\n",
        "    else:\n",
        "      roberta_prediction = predicted_sentiment_1\n",
        "      distilbert_prediction = predicted_sentiment_2\n",
        "\n",
        "    print(f'Example: {example}\\nActual Sentiment: {actual_sentiment}\\n'\n",
        "          f'Distilbert Prediction: {distilbert_prediction}\\n'\n",
        "          f'Roberta Prediction: {roberta_prediction}\\n------------')"
      ],
      "metadata": {
        "id": "kUQJNYOa9Q5-"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Run the pipeline\n",
        "\n",
        "Combine the previous steps into a single Apache Beam pipeline, and then run it."
      ],
      "metadata": {
        "id": "-LrpmM2PGAkf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with beam.Pipeline() as beam_pipeline:\n",
        "\n",
        "  formatted_examples = (\n",
        "      beam_pipeline\n",
        "      | \"ReadExamples\" >> beam.Create(examples)\n",
        "      | \"FormatExamples\" >> beam.ParDo(FormatExamples()))\n",
        "  inferences = (\n",
        "      formatted_examples\n",
        "      | \"RunInference\" >> RunInference(mh)\n",
        "      | \"ExtractResults\" >> beam.ParDo(ExtractResults())\n",
        "      | \"GroupByExample\" >> beam.GroupByKey())\n",
        "\n",
        "  inferences | \"PrintResults\" >> beam.ParDo(PrintResults())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 463
        },
        "id": "B9Wti3XH0Iqe",
        "outputId": "528ad732-ecf8-4877-ab6a-badad7944fed"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Example: This restaurant is awesome\n",
            "Actual Sentiment: positive\n",
            "Distilbert Prediction: POSITIVE\n",
            "Roberta Prediction: NEUTRAL\n",
            "------------\n",
            "Example: This restaurant is bad\n",
            "Actual Sentiment: negative\n",
            "Distilbert Prediction: NEGATIVE\n",
            "Roberta Prediction: NEUTRAL\n",
            "------------\n",
            "Example: I love chocolate\n",
            "Actual Sentiment: positive\n",
            "Distilbert Prediction: POSITIVE\n",
            "Roberta Prediction: NEUTRAL\n",
            "------------\n",
            "Example: I feel fine\n",
            "Actual Sentiment: neutral\n",
            "Distilbert Prediction: POSITIVE\n",
            "Roberta Prediction: ENTAILMENT\n",
            "------------\n"
          ]
        }
      ]
    }
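    ,
    {
      "cell_type": "markdown",
      "source": [
        "The printed output can also be scored. The following plain-Python sketch counts, per model, how often the predicted label matches the actual sentiment after lowercasing. The `results` list is copied by hand from the output above, and the function name `accuracy_by_model` is hypothetical. The RoBERTa labels include `ENTAILMENT`, which is not one of the sentiment labels used in the examples, so treat this direct comparison as a rough illustration and not as a fair benchmark.\n",
        "\n",
        "```python\n",
        "from collections import defaultdict\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "# (model, actual sentiment, predicted label), copied from the printed output.\n",
        "results = [\n",
        "    ('distilbert', 'positive', 'POSITIVE'),\n",
        "    ('roberta', 'positive', 'NEUTRAL'),\n",
        "    ('distilbert', 'negative', 'NEGATIVE'),\n",
        "    ('roberta', 'negative', 'NEUTRAL'),\n",
        "    ('distilbert', 'positive', 'POSITIVE'),\n",
        "    ('roberta', 'positive', 'NEUTRAL'),\n",
        "    ('distilbert', 'neutral', 'POSITIVE'),\n",
        "    ('roberta', 'neutral', 'ENTAILMENT'),\n",
        "]\n",
        "\n",
        "\n",
        "def accuracy_by_model(rows: List[Tuple[str, str, str]]) -> Dict[str, float]:\n",
        "  # Count exact label matches per model, ignoring letter case.\n",
        "  correct = defaultdict(int)\n",
        "  total = defaultdict(int)\n",
        "  for model, actual, predicted in rows:\n",
        "    total[model] += 1\n",
        "    if predicted.lower() == actual:\n",
        "      correct[model] += 1\n",
        "  return {model: correct[model] / total[model] for model in total}\n",
        "\n",
        "\n",
        "print(accuracy_by_model(results))\n",
        "```\n",
        "\n",
        "With only four examples, the resulting fractions say little about either model in general. The same pattern applies to a larger labeled dataset."
      ],
      "metadata": {}
    }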
  ]
}
